iT邦幫忙

2021 iThome 鐵人賽

DAY 13
0
AI & Data

30 天在 Colab 嘗試的 30 個影像分類訓練實驗系列 第 13

【13】模型套不套用資料增強 (Data Augmentation) 的比較實驗

  • 分享至 

  • xImage
  •  

Colab連結

資料增強(Data Augmentation),是一個當今天資料集樣本不多時,透過調整亮度、剪裁、角度等手法來增加多樣性的好方法,Tensorflow 的 tf.image.random_* API 提供了不少資料增強的方法,讓我們在訓練模型時可以使用。

這次我簡單介紹幾個 API 並看看,這幾種 Augmentation 方式會產生什麼樣的效果。

def aug_img(image, label):
  image = tf.cast(image, tf.float32)
  image = tf.image.resize(image, (224,224))
  flip_image = tf.image.random_flip_left_right(image)
  flip_image = tf.image.random_flip_up_down(flip_image)
  brt_img = tf.image.random_brightness(flip_image, 70)
  brt_img = tf.clip_by_value(brt_img, clip_value_min=0.0, clip_value_max=255.0)
  sat_img = tf.image.random_saturation(brt_img, 0.7, 1.5)
  sat_img = tf.clip_by_value(sat_img, clip_value_min=0.0, clip_value_max=255.0)
  cts_img = tf.image.random_contrast(sat_img, 0.6, 1.4)
  cts_img = tf.clip_by_value(cts_img, clip_value_min=0.0, clip_value_max=255.0)
  return image, flip_image, brt_img, sat_img, cts_img

random_flip:

就是隨機上下左右顛倒,像這次的資料集是花的辨識,花本身就沒有一定的方向性,就很適合拿來使用,但如果今天的資料集是貓狗二分類,那麼只需要左右顛倒即可。

random_brightness:

提供一個 max_delta 的值,會將圖片每個像數乘上這個的變化量,要注意的是,如果今天你的圖片已經先 normalize 到 [0.0, 1.0] 之間了,那這個值可以指設0.1就會產生很大的亮度差異,但如果今天圖片的範圍是[0, 255],那就需要設定比如70這樣大的數值去產生亮度差異。

random_saturation:

提供上限 upper 和下限 lower 來決定圖片的飽和度。

random_contrast:

和 random_saturation 雷同,對圖片隨機的對比度。

我們印出實際的圖片變化

原圖:

https://ithelp.ithome.com.tw/upload/images/20210927/20107299Z3TxMI3SPn.png

隨機顛倒:

https://ithelp.ithome.com.tw/upload/images/20210927/20107299zmWlxB8tMV.png

隨機亮度:

https://ithelp.ithome.com.tw/upload/images/20210927/20107299twNVtIVCV6.png

隨機飽和度:

https://ithelp.ithome.com.tw/upload/images/20210927/20107299ESpGBwCQgC.png

隨機對比度:

https://ithelp.ithome.com.tw/upload/images/20210927/20107299b13KnRDqjE.png

檢查完圖片都該有的變化後,我們先跑一次不做任何資料增強的訓練:

base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)

model = tf.keras.Model(inputs=[base.input], outputs=[net])

model.compile(
    optimizer=tf.keras.optimizers.SGD(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    verbose=True)

產出:

最佳成績:

loss: 5.2436e-04 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.5029 - val_sparse_categorical_accuracy: 0.8706

https://ithelp.ithome.com.tw/upload/images/20210927/20107299qb63bXaPKV.png

接下來,跑一下套用資料增強後的模型:

def aug_img(image, label):
  image = tf.cast(image, tf.float32)
  image = tf.image.random_flip_left_right(image)
  image = tf.image.random_flip_up_down(image)
  image = tf.image.resize(image, (224,224))
  image = tf.image.random_brightness(image, 70)
  image = tf.image.random_saturation(image, 0.7, 1.5)
  image = tf.image.random_contrast(image, 0.6, 1.4)
  image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
  return image / 255., label

ds_train = train_split.map(
    aug_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(SHUFFLE_SIZE)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

ds_test = test_split.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(BATCH_SIZE)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)

model = tf.keras.Model(inputs=[base.input], outputs=[net])

model.compile(
    optimizer=tf.keras.optimizers.SGD(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

start = timeit.default_timer()
history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    verbose=True)

產出:

loss: 5.2807e-04 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.6019 - val_sparse_categorical_accuracy: 0.8422

https://ithelp.ithome.com.tw/upload/images/20210927/20107299l4ikTAA2Gb.png

這次實驗結果顯示,最終的成績並沒有第一個模型好,相比準確度低了3%,但也不致於差到哪裡去,資料增強仍然是一個我實務上常使用的方法。


上一篇
【12】新手容易忽略的 logit 與 loss 之間的搭配
下一篇
【14】如果不做圖片標準化(Normalization)會怎麼樣
系列文
30 天在 Colab 嘗試的 30 個影像分類訓練實驗31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言